{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pytorch supervised learning of perceptual decision making task\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neurogym/neurogym/blob/master/examples/example_neurogym_pytorch.ipynb)\n",
    "\n",
    "Pytorch-based example code for training a RNN on a perceptual decision-making task.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Installation on google colab\n",
    "\n",
    "Uncomment and execute cell below if running on google colab.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Install gymnasium\n",
    "# ! pip install gymnasium\n",
    "# # Install neurogym\n",
    "# ! git clone https://github.com/neurogym/neurogym.git\n",
    "# %cd neurogym/\n",
    "# ! pip install -e ."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import neurogym as ngym\n",
    "\n",
    "# Environment\n",
    "task = 'PerceptualDecisionMaking-v0'\n",
    "kwargs = {'dt': 100}\n",
    "seq_len = 100\n",
    "\n",
    "# Make supervised dataset\n",
    "dataset = ngym.Dataset(task, env_kwargs=kwargs, batch_size=16,\n",
    "                       seq_len=seq_len)\n",
    "env = dataset.env\n",
    "ob_size = env.observation_space.shape[0]\n",
    "act_size = env.action_space.n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Network and Training\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "200 loss: 0.11658\n",
      "400 loss: 0.03191\n",
      "600 loss: 0.01894\n",
      "800 loss: 0.01393\n",
      "1000 loss: 0.01351\n",
      "1200 loss: 0.01275\n",
      "1400 loss: 0.01280\n",
      "1600 loss: 0.01279\n",
      "1800 loss: 0.01223\n",
      "2000 loss: 0.01239\n",
      "Finished Training\n"
     ]
    }
   ],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self, num_h):\n",
    "        super(Net, self).__init__()\n",
    "        self.lstm = nn.LSTM(ob_size, num_h)\n",
    "        self.linear = nn.Linear(num_h, act_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out, hidden = self.lstm(x)\n",
    "        x = self.linear(out)\n",
    "        return x\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "net = Net(num_h=64).to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)\n",
    "\n",
    "running_loss = 0.0\n",
    "for i in range(2000):\n",
    "    inputs, labels = dataset()\n",
    "    inputs = torch.from_numpy(inputs).type(torch.float).to(device)\n",
    "    labels = torch.from_numpy(labels.flatten()).type(torch.long).to(device)\n",
    "\n",
    "    # zero the parameter gradients\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # forward + backward + optimize\n",
    "    outputs = net(inputs)\n",
    "\n",
    "    loss = criterion(outputs.view(-1, act_size), labels)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # print statistics\n",
    "    running_loss += loss.item()\n",
    "    if i % 200 == 199:\n",
    "        print('{:d} loss: {:0.5f}'.format(i + 1, running_loss / 200))\n",
    "        running_loss = 0.0\n",
    "\n",
    "print('Finished Training')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analysis\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average performance in 200 trials\n",
      "0.83\n"
     ]
    }
   ],
   "source": [
    "perf = 0\n",
    "num_trial = 200\n",
    "for i in range(num_trial):\n",
    "    env.new_trial()\n",
    "    ob, gt = env.ob, env.gt\n",
    "    ob = ob[:, np.newaxis, :]  # Add batch axis\n",
    "    inputs = torch.from_numpy(ob).type(torch.float).to(device)\n",
    "\n",
    "    action_pred = net(inputs)\n",
    "    action_pred = action_pred.detach().numpy()\n",
    "    action_pred = np.argmax(action_pred, axis=-1)\n",
    "    perf += gt[-1] == action_pred[-1, 0]\n",
    "\n",
    "perf /= num_trial\n",
    "print('Average performance in {:d} trials'.format(num_trial))\n",
    "print(perf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
